# agent.py
"""
Defines the QLearningAgent class.
The agent's state includes a discretized version of the planner's dual variable (mu).
Includes functionality for an evaluation mode where learning and exploration are turned off.
"""

import functools
import random
from collections import defaultdict

import numpy as np

from utils import discretize # Import helper function
import config # Import configuration

class QLearningAgent:
    """
    Tabular Q-learning agent that learns which value to report.

    A state is the tuple ``(value_bin, round_t, mu_bins)``:
      * ``value_bin`` is the discretized private value,
      * ``round_t`` is the current round index,
      * ``mu_bins`` is a tuple with one discretized bin per component of the
        planner's dual variable (mu).

    Actions are integer indices in ``range(num_report_actions)``.

    In evaluation mode (see ``set_evaluation_mode``) the agent:
      * acts purely greedily,
      * never updates or grows its Q-table,
      * does not decay epsilon.
    """

    def __init__(self, agent_id,
                 num_report_actions=None, # Allow override, else use config
                 report_range=None,       # Allow override, else use config
                 learning_rate=None,      # Allow override, else use config (alpha)
                 discount_factor=None,    # Allow override, else use config (gamma)
                 initial_epsilon=None,    # Allow override, else use config (EPSILON_START)
                 epsilon_decay=None,      # Allow override, else use config (EPSILON_DECAY)
                 min_epsilon=None         # Allow override, else use config (EPSILON_END)
                 ):
        """
        Create an agent, filling every argument left as ``None`` from ``config``.

        Args:
            agent_id: Identifier used in mode-change messages.
            num_report_actions (int, optional): Number of discrete report
                actions (default ``config.NUM_REPORT_ACTIONS``).
            report_range (optional): Range of reportable values (default
                ``config.REPORT_RANGE``). It is stored but not read by this class.
            learning_rate (float, optional): Step size alpha (default ``config.ALPHA``).
            discount_factor (float, optional): Discount gamma (default ``config.GAMMA``).
            initial_epsilon (float, optional): Starting exploration rate
                (default ``config.EPSILON_START``).
            epsilon_decay (float, optional): Multiplicative factor applied per
                ``decay_epsilon`` call (default ``config.EPSILON_DECAY``).
            min_epsilon (float, optional): Lower bound for epsilon
                (default ``config.EPSILON_END``).
        """
        self.agent_id = agent_id
        # Environment-level settings always come from the global config.
        self.k = config.K
        self.t_max = config.T
        self.num_value_bins = config.NUM_VALUE_BINS

        self.num_report_actions = num_report_actions if num_report_actions is not None else config.NUM_REPORT_ACTIONS
        self.value_range = config.VALUE_RANGE  # passed to discretize() for private values
        self.report_range = report_range if report_range is not None else config.REPORT_RANGE

        self.gamma = discount_factor if discount_factor is not None else config.GAMMA
        self.alpha = learning_rate if learning_rate is not None else config.ALPHA

        self.initial_epsilon = initial_epsilon if initial_epsilon is not None else config.EPSILON_START
        self.epsilon_decay_factor = epsilon_decay if epsilon_decay is not None else config.EPSILON_DECAY
        self.min_epsilon = min_epsilon if min_epsilon is not None else config.EPSILON_END

        self.epsilon = self.initial_epsilon

        # Discretization settings for the planner's dual variable (mu).
        self.num_mu_bins = config.NUM_MU_BINS
        self.mu_range = config.MU_RANGE

        # Q-table: maps state -> float array of Q-values, one per report action.
        # functools.partial (not a lambda) keeps the table, and therefore the
        # agent, picklable, so a trained agent can be saved.
        self.q_table = defaultdict(functools.partial(np.zeros, self.num_report_actions))

        # When True: greedy actions, no Q-updates, no epsilon decay.
        self.is_eval_mode = False

    def get_state(self, private_value, current_round_t, current_mu_vector):
        """
        Build the discrete state tuple from the agent's raw observations.

        Args:
            private_value (float): Agent's current private value.
            current_round_t (int): Current round number (0 to T-1).
            current_mu_vector: The planner's dual variable. Accepted forms:
                * a list or tuple of numbers,
                * a numpy array of any shape (it is flattened),
                * a single numeric scalar, including numpy scalars,
                * ``None``, meaning mu is unavailable. This is treated as
                  ``[0.0] * config.COST_DIM``.

        Returns:
            tuple: ``(value_bin, current_round_t, mu_bins_tuple)``.

        Raises:
            TypeError: If ``current_mu_vector`` is none of the accepted forms.
        """
        value_bin = discretize(private_value, self.num_value_bins, self.value_range)

        if current_mu_vector is None:
            # Planner does not provide mu: fall back to an all-zero vector.
            mu_components = [0.0] * config.COST_DIM
        elif isinstance(current_mu_vector, np.ndarray):
            # ravel() also covers 0-d arrays (which are not iterable) and
            # multi-dimensional arrays (which would otherwise yield rows).
            mu_components = current_mu_vector.ravel()
        elif isinstance(current_mu_vector, (list, tuple)):
            mu_components = current_mu_vector
        elif isinstance(current_mu_vector, (int, float, np.number)):
            # A scalar mu (e.g. COST_DIM == 1), including numpy scalar types
            # such as np.float32 / np.int64, which are not int/float subclasses.
            mu_components = [current_mu_vector]
        else:
            raise TypeError(
                f"current_mu_vector is of unexpected type: {type(current_mu_vector)}. "
                "Expected a list, tuple, np.ndarray, numeric scalar, or None."
            )

        # A tuple (not a list) keeps the state hashable for use as a dict key.
        mu_bins_tuple = tuple(
            discretize(val, self.num_mu_bins, self.mu_range) for val in mu_components
        )

        return (value_bin, current_round_t, mu_bins_tuple)

    def choose_action(self, state):
        """
        Choose a report-action index with an epsilon-greedy policy.

        In evaluation mode epsilon is treated as 0, so the choice is purely
        greedy. Looking up an unseen state does not insert it into the Q-table.

        Args:
            state (tuple): A state as returned by ``get_state``.

        Returns:
            int: Action index in ``[0, num_report_actions)``.
        """
        current_epsilon = 0.0 if self.is_eval_mode else self.epsilon

        if random.random() < current_epsilon:
            # Exploration: uniform random action.
            return random.randint(0, self.num_report_actions - 1)

        # Exploitation. .get() does not invoke the defaultdict factory, so
        # unseen states are not added as zero rows (notably during evaluation).
        q_values = self.q_table.get(state)
        if q_values is None or not np.any(q_values):
            # Unseen or all-zero row: every action ties, so pick uniformly.
            return random.randint(0, self.num_report_actions - 1)

        # Break ties uniformly among all maximal actions. Return a plain int,
        # matching the exploration branch.
        best_actions = np.flatnonzero(q_values == np.max(q_values))
        return int(np.random.choice(best_actions))

    def update_q_table(self, state, action, reward, next_state):
        """
        Apply one Q-learning update to Q(state, action).

            Q(s,a) <- Q(s,a) + alpha * (reward + gamma * max_a' Q(s',a') - Q(s,a))

        The bootstrap term max_a' Q(s',a') is 0 in two cases:
          * ``next_state`` is None, which marks a terminal transition;
          * ``next_state`` has not been seen yet. Its row would be all zeros,
            and it is not inserted into the Q-table.

        Nothing is updated in evaluation mode.

        Args:
            state (tuple): State in which ``action`` was taken.
            action (int): Index of the action taken.
            reward (float): Reward received for the transition.
            next_state (tuple or None): Successor state, or None if terminal.
        """
        if self.is_eval_mode:
            return  # Evaluation must leave the learned table untouched.

        if next_state is None:
            next_max = 0.0
        else:
            q_next_values = self.q_table.get(next_state)
            next_max = 0.0 if q_next_values is None else np.max(q_next_values)

        q_row = self.q_table[state]  # defaultdict creates a zero row if unseen
        td_target = reward + self.gamma * next_max
        q_row[action] += self.alpha * (td_target - q_row[action])

    def decay_epsilon(self):
        """
        Multiply epsilon by the decay factor, never going below ``min_epsilon``.

        No decay occurs in evaluation mode.
        """
        if self.is_eval_mode:
            return  # Do not decay epsilon during evaluation

        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay_factor)

    def set_evaluation_mode(self, is_eval):
        """
        Switch the agent between training and evaluation mode and print the change.

        In evaluation mode actions are greedy and neither Q-updates nor
        epsilon decay occur. Switching back to training leaves ``epsilon``
        at the value it had before evaluation.

        Args:
            is_eval: Truthy for evaluation mode, falsy for training mode.
        """
        self.is_eval_mode = bool(is_eval)
        if self.is_eval_mode:
            print(f"Agent {self.agent_id}: Switched to EVALUATION mode (epsilon=0, no Q-updates, no epsilon decay).")
        else:
            print(f"Agent {self.agent_id}: Switched to TRAINING mode (epsilon={self.epsilon:.4f}).")
